
from Network.network_utils import run_optimizer
from ActualCausal.Train.train_utils import compute_likelihood
from ActualCausal.Train.regularizers import apply_regularizers

def train_passive(args, params, model, buffer, single=False, name="", log_batch=None, wrap_function=None,
                  additional=None, itr_num=0, intermediate_logger=None):
    """Run ``args.passive.steps`` optimizer updates on the passive model head.

    Each step samples a batch from ``buffer``, runs ``model.infer`` for the
    ``"passive"`` (or ``"single_passive"``) head, passes the negative
    ``log_probs`` through ``apply_regularizers`` and applies one optimizer step
    via ``run_optimizer``.

    Args:
        args: config namespace. Reads ``passive.steps``, ``passive.include_gradient``,
            ``train.batch_size``, ``inter.regularization.splitting.splitting_passive``
            and ``inter.regularization.expt_passive``.
        params: provides ``sample_passive_weights`` (forwarded to ``buffer.sample``)
            and is forwarded to ``apply_regularizers``.
        model: object exposing ``infer(...)`` and ``get_model_optim(names)``.
        buffer: object exposing ``sample(batch_size, weights) -> (batch, idxes)``.
        single: if True, train the ``"single_passive"`` head instead of ``"passive"``.
        name: currently unused in this function; kept for interface compatibility.
        log_batch: forwarded to ``model.infer``; ``None`` means a fresh empty list.
        wrap_function: optional callable applied to every sampled batch.
        additional: forwarded to ``model.infer``; ``None`` means a fresh empty list.
        itr_num: outer iteration index, used to compute the global logging step.
        intermediate_logger: optional object whose ``log`` method is called after
            every step.

    Returns:
        The inference result of the final step (with ``gradients`` attached),
        or ``None`` when ``args.passive.steps`` is 0.
    """
    # Fresh lists per call: list defaults in the signature would be shared
    # between calls (and leak any mutation made by model.infer).
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional
    mod_type = "single_passive" if single else "passive"

    # Regularizers to disable for passive training. This depends only on the
    # config, so it is computed once rather than on every step.
    skip_names = list()
    if args.inter.regularization.splitting.splitting_passive:
        skip_names.append("splitting")
    if args.inter.regularization.expt_passive:
        skip_names.append("expectile")

    result = None  # returned unchanged if the loop body never runs (steps == 0)
    for i in range(args.passive.steps):
        batch, _ = buffer.sample(args.train.batch_size, params.sample_passive_weights)
        if wrap_function is not None:
            batch = wrap_function(batch)
        # model.infer already fills result[mod_type].log_probs, so no separate
        # compute_likelihood call is needed. (Original note: adds logits
        # [batch, num_obj*obj_dim].)
        result = model.infer(batch, batch.valid, [mod_type], log_batch=log_batch, additional=additional)
        grad_variables = [result.passive_input] if args.passive.include_gradient else list()
        compute_models, optims = model.get_model_optim([mod_type])
        optim, compute_model = optims[0], compute_models[0]
        loss = apply_regularizers(-result[mod_type].log_probs, args, params, model, batch,
                                  results=result[mod_type], skip_names=skip_names)
        result.gradients = run_optimizer(optim, compute_model, loss, grad_variables=grad_variables)
        if intermediate_logger is not None:
            intermediate_logger.log(itr_num * args.passive.steps + i, {"passive": result},
                                    intermediate_name="_passive")
    return result
